import json
import codecs
import os
from conlleval import return_report,report_notprint,metrics,evaluate


def load_squad_predict(path):
    """Load span predictions from ``predictions.json`` inside ``path``.

    The file must contain a JSON object whose values are objects carrying
    ``"start_index"`` and ``"end_index"`` entries.

    Args:
        path: Directory that contains ``predictions.json``.

    Returns:
        A tuple ``(ans_id, start_index, end_index)`` of three parallel lists:
        the keys of the JSON object in file order, and the ``start_index`` /
        ``end_index`` value stored under each key.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the file is not valid JSON.
        KeyError: If an entry lacks ``"start_index"`` or ``"end_index"``.
    """
    file_path = os.path.join(path, "predictions.json")
    # The context manager closes the handle even when parsing fails.
    with codecs.open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    ans_id = list(data)
    start_index = [data[key]["start_index"] for key in ans_id]
    end_index = [data[key]["end_index"] for key in ans_id]
    return ans_id, start_index, end_index

def load_chosen_test(f):
    """Parse a blank-line-delimited, three-column tag file into sentences.

    Every non-blank line is stripped and split on single spaces; its first
    three fields are taken as (character, true tag, predicted tag). Blank
    lines end the current sentence; consecutive blank lines are ignored.
    A final sentence that is not followed by a blank line is kept as well
    (previously it was silently dropped).

    Args:
        f: Iterable of text lines (for example an open text file).

    Returns:
        A tuple ``(char, true_tag, pre_tag)`` of three parallel lists, each
        holding one inner list per sentence.

    Raises:
        IndexError: If a non-blank line has fewer than three fields.
    """
    char, true_tag, pre_tag = [], [], []

    def _store_sentence(rows):
        # Split the collected (char, true, pred) rows into the three outputs.
        char.append([row[0] for row in rows])
        true_tag.append([row[1] for row in rows])
        pre_tag.append([row[2] for row in rows])

    current = []
    for line in f:
        stripped = line.strip()
        if stripped:
            fields = stripped.split(" ")
            current.append((fields[0], fields[1], fields[2]))
        elif current:
            _store_sentence(current)
            current = []

    # Flush the trailing sentence when the input lacks a terminating blank line.
    if current:
        _store_sentence(current)

    return char, true_tag, pre_tag


def merge_data(char, true_tag, pre_tag_new, f2, f3):
    """Write the re-tagged sentences to ``f2`` and then append ``f3`` to it.

    For each sentence, one ``"<char> <true tag> <new tag>"`` line is written
    per position, followed by an empty line. Afterwards every line read from
    ``f3`` is copied to ``f2`` unchanged.

    Args:
        char: List of sentences, each a list of character strings.
        true_tag: True tags, indexed like ``char``.
        pre_tag_new: Replacement predicted tags, indexed like ``char``.
        f2: Writable text stream that receives the output.
        f3: Iterable of lines to append verbatim.
    """
    for sent_no, tokens in enumerate(char):
        gold_tags = true_tag[sent_no]
        new_tags = pre_tag_new[sent_no]
        for pos, token in enumerate(tokens):
            row = " ".join((token, gold_tags[pos], new_tags[pos]))
            f2.write(row + "\n")
        # An empty line separates consecutive sentences.
        f2.write("\n")

    for passthrough_line in f3:
        f2.write(passthrough_line)



def generate_new_test(ans_id, start_index, end_index, path):
    """Build ``new_test.utf8`` by re-tagging ``chosen_test.utf8`` from spans.

    Files used inside ``path``:

    * ``que2con.json`` (read): JSON object whose values are used as indices
      into the sentences of ``chosen_test.utf8``.
    * ``chosen_test.utf8`` (read): three-column sentences, parsed with
      ``load_chosen_test``.
    * ``other_test.utf8`` (read): copied verbatim after the re-tagged data.
    * ``new_test.utf8`` (written): the merged result.

    Every predicted tag starts as ``"O"``. For the i-th entry of
    ``que2con.json``, positions ``start_index[i]`` to ``end_index[i]``
    (inclusive) of the referenced sentence are set to ``"B-AS"`` for the
    first position and ``"I-AS"`` for the rest. A span whose end precedes
    its start leaves the sentence untouched.

    Args:
        ans_id: Not used; kept so existing callers keep working. The keys
            are taken from ``que2con.json`` instead.
        start_index: Span start positions, one per ``que2con.json`` entry.
        end_index: Span end positions (inclusive), aligned with
            ``start_index``.
        path: Directory holding the files listed above.

    Raises:
        OSError: If one of the files cannot be opened.
        IndexError: If there are fewer spans than ``que2con.json`` entries,
            or a sentence index / position is out of range.
    """
    chosen_path = os.path.join(path, "chosen_test.utf8")
    mapping_path = os.path.join(path, "que2con.json")
    output_path = os.path.join(path, "new_test.utf8")
    other_path = os.path.join(path, "other_test.utf8")

    with codecs.open(mapping_path, "r", encoding="utf-8") as mapping_file:
        que2con = json.load(mapping_file)
    # NOTE(review): spans are paired with these sentence indices purely by
    # position, so this assumes que2con.json lists its keys in the same
    # order as the predictions the spans were read from -- confirm, or
    # switch to a lookup by ``ans_id``.
    con_id = list(que2con.values())

    with codecs.open(chosen_path, "r", encoding="utf-8") as chosen_file:
        char, true_tag, pre_tag = load_chosen_test(chosen_file)

    # Start from all-"O" predictions shaped like the parsed ones.
    pre_tag_new = [["O"] * len(sentence_tags) for sentence_tags in pre_tag]

    for i, sentence_idx in enumerate(con_id):
        span_start = start_index[i]
        span_end = end_index[i]
        for j in range(span_start, span_end + 1):
            pre_tag_new[sentence_idx][j] = "B-AS" if j == span_start else "I-AS"

    # Open the output last so it is not truncated when an input is missing
    # or malformed; the context managers close both handles on any error.
    with codecs.open(other_path, "r", encoding="utf-8") as other_file, \
            codecs.open(output_path, "w", encoding="utf-8") as output_file:
        merge_data(char, true_tag, pre_tag_new, output_file, other_file)

def pre_process(path):
    """Scan ``chosen_test.utf8`` and create an empty ``query_id.json``.

    ``query_id.json`` is opened for writing (created or truncated) but
    nothing is written to it. Every non-blank line of ``chosen_test.utf8``
    must have at least three space-separated fields.

    NOTE(review): the parsed data is never used or written anywhere, so this
    function looks unfinished -- confirm the intended output before relying
    on it.

    Args:
        path: Directory that contains ``chosen_test.utf8`` and receives
            ``query_id.json``.

    Raises:
        OSError: If a file cannot be opened.
        IndexError: If a non-blank line has fewer than three fields.
    """
    file_path = os.path.join(path, "chosen_test.utf8")
    file_path_1 = os.path.join(path, "query_id.json")
    # Context managers guarantee both handles are closed even when a
    # malformed line raises part-way through the scan.
    with open(file_path, "r", encoding="utf-8") as f, \
            open(file_path_1, "w", encoding="utf-8"):
        for line in f:
            stripped = line.strip()
            if not stripped:
                continue
            fields = stripped.split(" ")
            # Indexing keeps the IndexError on rows with fewer than three
            # columns.
            _ = fields[2]

def evaluate_pre_now(path):
    """Print the evaluation score before and after repositioning.

    Runs ``return_report`` on ``new_test.utf8`` and ``predict.utf8`` inside
    ``path``, takes the last whitespace-separated token of the second line
    of each report as a float (presumably the F1 score -- confirm against
    the report layout), and prints the ``predict.utf8`` value under
    ``before_repositioning:`` followed by the ``new_test.utf8`` value under
    ``after_repositioning:``.

    Args:
        path: Directory that contains both files.
    """
    repositioned_report = return_report(os.path.join(path, "new_test.utf8"))
    baseline_report = return_report(os.path.join(path, "predict.utf8"))

    repositioned_score = float(repositioned_report[1].strip().split()[-1])
    baseline_score = float(baseline_report[1].strip().split()[-1])

    labelled_scores = (
        ("before_repositioning:", baseline_score),
        ("after_repositioning:", repositioned_score),
    )
    for label, score in labelled_scores:
        print(label)
        print(score)

if __name__ == "__main__":
    # Script entry point: read the span predictions, rebuild the tagged test
    # file from them, then print the scores before and after repositioning.
    prediction_dir = "repositioning_base"
    result_dir = "test_result"
    question_ids, span_starts, span_ends = load_squad_predict(prediction_dir)
    generate_new_test(question_ids, span_starts, span_ends, result_dir)
    evaluate_pre_now(result_dir)





